#include "kernel_operator.h"
using namespace AscendC;
constexpr int32_t BUFFER_NUM = 2;

template<typename T>
class KernelFastgelu{
public:
    __aicore__ inline KernelFastgelu() {}
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR out,
                              uint32_t blockLength,
                              uint32_t tileNum, uint32_t tileLength,
                              uint32_t lasttileLength, uint32_t formerNum,
                              uint32_t formerLength, uint32_t formertileNum,
                              uint32_t formertileLength,
                              uint32_t formerlasttileLength, uint32_t tailNum,
                              uint32_t tailLength, uint32_t tailtileNum,
                              uint32_t tailtileLength,
                              uint32_t taillasttileLength, uint32_t tilingKey) {
        ASSERT(GetBlockNum() != 0 && "block dim can not be zero!");
        if (tilingKey == 1) {
            this->blockLength = blockLength;
            this->tileNum =
                tileNum ASSERT(tileNum != 0 && "tile num can not be zero!");
            this->tileLength = tileLength / BUFFER_NUM;
            this->lasttileLength = lasttileLength;

            xGm.SetGlobalBuffer((__gm__ T*)x + this->blockLength * GetBlockIdx(),
                                this->blockLength);
            outGm.SetGlobalBuffer(
                (__gm__ T*)out + this->blockLength * GetBlockIdx(),
                this->blockLength);
        }

        if (tilingKey == 2) {
            this->formerNum = formerNum;
            this->formerLength = formerLength;
            this->formertileNum = formertileNum;
            this->formertileLength = formertileLength;
            this->formerlasttileLength = formerlasttileLength;

            this->tailNum = tailNum;
            this->tailLength = tailLength;
            this->tailtileNum = tailtileNum;
            this->tailtileLength = tailtileLength;
            this->taillasttileLength = taillasttileLength;

            if (GetBlockIdx() < this->formerNum) {  //分到大块核的处理
                this->tileLength = this->formertileLength / BUFFER_NUM;
                this->lasttileLength = this->formerlasttileLength;
                this->tileNum = this->formertileNum * BUFFER_NUM;
                xGm.SetGlobalBuffer(
                    (__gm__ T*)x + this->formerLength * GetBlockIdx(),
                    this->formerLength);
                outGm.SetGlobalBuffer(
                    (__gm__ T*)out + this->formerLength * GetBlockIdx(),
                    this->formerLength);
            } else {  //分到小块核的处理，需要处理的数据量比大核少alignNum个
                this->tileLength = this->tailtileLength / BUFFER_NUM;
                this->lasttileLength = this->taillasttileLength;
                this->tileNum = this->tailtileNum * BUFFER_NUM;
                xGm.SetGlobalBuffer(
                    (__gm__ T*)x + this->formerLength * this->formerNum +
                        this->tailLength * (GetBlockIdx() - this->formerNum),
                    this->tailLength);
                outGm.SetGlobalBuffer(
                    (__gm__ T*)out + this->formerLength * this->formerNum +
                        this->tailLength * (GetBlockIdx() - this->formerNum),
                    this->tailLength);
            }
        }

        pipe.InitBuffer(inQueueIN, BUFFER_NUM, this->tileLength * 1 * sizeof(T));
        pipe.InitBuffer(outQueueOUT, BUFFER_NUM, this->tileLength * sizeof(T));
    
        pipe.InitBuffer(inQueueC, this->tileLength * sizeof(T));
        pipe.InitBuffer(tmpQ, this->tileLength * sizeof(T));

        // pipe.InitBuffer(param1Queue, this->tileLength * sizeof(T));
        // pipe.InitBuffer(param2Queue, this->tileLength * sizeof(T));
        // pipe.InitBuffer(param3Queue, this->tileLength * sizeof(T));
    }

    __aicore__ inline void Process() {
        int32_t loopCount = this->tileNum * BUFFER_NUM;
        for (int32_t i = 0; i < loopCount; i++) {
            CopyIn(i);
            Compute(i);
            CopyOut(i);
        }
    }

private:
    __aicore__ inline void CopyIn(int32_t progress) {
        LocalTensor<T> inLocal = inQueueIN.AllocTensor<T>();
        if (BUFFER_NUM == 1) {
            if (progress == this->tileNum - 1) {
                if (progress == 0) {
                //如果只有一包，则搬运的起始地址为0，tileLength为实际分块的数据量
                DataCopy(inLocal[0], xGm[0], this->tileLength);
                } else {
                //将最后一个分块的起始地址向前移动tileLength-lasttileLength
                DataCopy(
                    inLocal[0],
                    xGm[(progress - 1) * this->tileLength + this->lasttileLength],
                    this->tileLength);
                }
            } else {
                DataCopy(inLocal[0], xGm[progress * this->tileLength],
                        this->tileLength);
            }
        }
        if (BUFFER_NUM == 2) {
            //开启double
            //buffer时，由于将输入数据分成了相等的2部分，分块大小为不开启double
            //buffer的一半， 所以需要对最后两个分块数据的起始地址做处理
            if ((progress == (this->tileNum * BUFFER_NUM - 2)) ||
                (progress == (this->tileNum * BUFFER_NUM - 1))) {
                //分块大小变为tileLength的一半
                //倒数第2个分块数据的起始地址向前移动（tileLength-lasttileLength)，最后一个分块的起始地址以此为基础进行移动
                DataCopy(
                    inLocal[0],
                    xGm[(progress - 2) * (this->tileLength) + this->lasttileLength],
                    (this->tileLength));
            }

            else {
                DataCopy(inLocal[0], xGm[progress * (this->tileLength)],
                        (this->tileLength));
            }
        }

        inQueueIN.EnQue(inLocal);
    }

    __aicore__ inline void Compute(int32_t progress) {
        LocalTensor<T> yLocal = outQueueOUT.AllocTensor<T>();
        LocalTensor<T> xLocal = inQueueIN.DeQue<T>();

        LocalTensor<T> cLocal = inQueueC.Get<T>();
        Abs(cLocal, xLocal, this->tileLength);

        // LocalTensor<T> param1 = param1Queue.Get<T>();
        // for(int i = 0; i < this->tileLength; i++){
        //     param1.SetValue(i, T(-1.702));
        // }
        Muls(cLocal, cLocal, (param1), this->tileLength);        

        Exp(cLocal, cLocal, this->tileLength);
        // LocalTensor<T> param2 = param2Queue.Get<T>();
        // for(int i = 0; i < this->tileLength; i++){
        //     param2.SetValue(i, T(1.0));
        // }   
        Adds(cLocal, cLocal, (param3), this->tileLength); 

        Div(cLocal, xLocal, cLocal, this->tileLength);

        LocalTensor<T> tmp = tmpQ.Get<T>();
        Abs(tmp, xLocal, this->tileLength);
        Sub(tmp, xLocal, tmp, this->tileLength);

        // LocalTensor<T> param3 = param3Queue.Get<T>();
        // for(int i = 0; i < this->tileLength; i++){
        //     param3.SetValue(i, T(0.851));
        // }  
        Muls(tmp, tmp, (param2), this->tileLength);
        Exp(tmp, tmp, this->tileLength);

        Mul(yLocal, cLocal, tmp, this->tileLength);
        outQueueOUT.EnQue<T>(yLocal);
        inQueueIN.FreeTensor(xLocal);
    }

    __aicore__ inline void CopyOut(int32_t progress) {
        LocalTensor<T> outLocal = outQueueOUT.DeQue<T>();
        if (BUFFER_NUM == 1) {
            if (progress == this->tileNum - 1) {
                if (progress == 0) {
                //如果只有一包，则搬运的起始地址为0，tileLength为实际分块的数据量
                DataCopy(outGm[0], outLocal, this->tileLength);
                } else {
                //将最后一个分块的起始地址向前移动tileLength-lasttileLength
                DataCopy(
                    outGm[(progress - 1) * this->tileLength + this->lasttileLength],
                    outLocal, this->tileLength);
                }
            } else {
                DataCopy(outGm[progress * this->tileLength], outLocal,
                        this->tileLength);
            }
        }
        if (BUFFER_NUM == 2) {
        //开启double
        //buffer时，由于将输入数据分成了相等的2部分，分块大小为不开启double
        //buffer的一半， 所以需要对最后两个分块数据的起始地址做处理
        if ((progress == (this->tileNum * BUFFER_NUM - 2)) ||
            (progress == (this->tileNum * BUFFER_NUM - 1))) {
            //分块大小变为tileLength的一半
            //倒数第2个分块数据的起始地址向前移动（tileLength-lasttileLength)，最后一个分块的起始地址以此为基础进行移动
            DataCopy(
                outGm[(progress - 2) * (this->tileLength) + this->lasttileLength],
                outLocal, (this->tileLength));
        }

        else {
            DataCopy(outGm[progress * (this->tileLength)], outLocal,
                    (this->tileLength));
        }
        }

        outQueueOUT.FreeTensor(outLocal);
    }

private:
    TPipe pipe;
    // TQue<QuePosition::VECIN, BUFFER_NUM> inQueueX, inQueueY, inQueueZ;
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueIN;
    TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueOUT;
    TBuf<QuePosition::VECCALC> inQueueC, tmpQ;
    // TBuf<QuePosition::VECCALC> param1Queue, param2Queue, param3Queue;
    
    GlobalTensor<T> xGm;
    GlobalTensor<T> outGm;
    uint32_t blockLength;
    uint32_t tileNum;
    uint32_t tileLength;
    uint32_t lasttileLength;
    uint32_t formerNum;
    uint32_t formerLength;
    uint32_t formertileNum;
    uint32_t formertileLength;
    uint32_t formerlasttileLength;
    uint32_t tailNum;
    uint32_t tailLength;
    uint32_t tailtileNum;
    uint32_t tailtileLength;
    uint32_t taillasttileLength;

    T param1 = -1.702;
    T param2 = 0.851;
    T param3 = 1.0;

};



extern "C" __global__ __aicore__ void fast_gelu(GM_ADDR x, GM_ADDR out,
                            GM_ADDR workspace, GM_ADDR tiling) {

    GET_TILING_DATA(tiling_data, tiling);
    // TODO: user kernel impl
    if(tiling_data.datatype == 1){ //fp16 half 
        KernelFastgelu<half> op;
        uint32_t tilingKey = 1;
        if (TILING_KEY_IS(1)) {
            tilingKey = 1;
        } else if (TILING_KEY_IS(2)) {
            tilingKey = 2;
        } else {
            tilingKey = 1;
        }
        op.Init(x, out, tiling_data.blockLength,
                tiling_data.tileNum, tiling_data.tileLength,
                tiling_data.lasttileLength, tiling_data.formerNum,
                tiling_data.formerLength, tiling_data.formertileNum,
                tiling_data.formertileLength, tiling_data.formerlasttileLength,
                tiling_data.tailNum, tiling_data.tailLength, tiling_data.tailtileNum,
                tiling_data.tailtileLength, tiling_data.taillasttileLength,
                tilingKey);
        op.Process();
    }else if(tiling_data.datatype == 0){ //fp32
        KernelFastgelu<float> op;
        uint32_t tilingKey = 1;
        if (TILING_KEY_IS(1)) {
            tilingKey = 1;
        } else if (TILING_KEY_IS(2)) {
            tilingKey = 2;
        } else {
            tilingKey = 1;
        }

        op.Init(x, out, tiling_data.blockLength,
                tiling_data.tileNum, tiling_data.tileLength,
                tiling_data.lasttileLength, tiling_data.formerNum,
                tiling_data.formerLength, tiling_data.formertileNum,
                tiling_data.formertileLength, tiling_data.formerlasttileLength,
                tiling_data.tailNum, tiling_data.tailLength, tiling_data.tailtileNum,
                tiling_data.tailtileLength, tiling_data.taillasttileLength,
                tilingKey);
        op.Process();
    }      
}

#ifndef __CCE_KT_TEST__
// Host-side launcher (compiled only when __CCE_KT_TEST__ is not defined):
// launches the fast_gelu kernel with `blockDim` blocks on `stream`, forwarding
// the device buffer pointers unchanged.
void fast_gelu_do(uint32_t blockDim, void* l2ctrl, void* stream,
                  uint8_t* x, uint8_t* out, uint8_t* workspace, uint8_t* tiling)
{
    fast_gelu<<<blockDim, l2ctrl, stream>>>(x, out, workspace, tiling);
}
#endif


